"""
Run script: functional genome-wide association analysis (FGWAS) pipeline, step 1 - fit the multivariate varying coefficient model (MVCM)
Usage: python ./fgwas_step1.py ./data/ ./result/

Author: Chao Huang (chaohuang.stat@gmail.com)
Last update: 2017-12-18
"""

import sys
import os
import time
from scipy.io import loadmat, savemat
import numpy as np
from numpy.linalg import inv
from multiprocessing import Pool
from stat_read_xy import read_xy
from stat_kernel import ep_kernel
from stat_mvcm import mvcm

"""
All of the libraries imported above must be installed before running this script.
"""


def bw(resy_design, coord_mat, vh, h_ii):
    """
        Score one candidate bandwidth for the kernel smoother used in MVCM.

        Each residual curve is smoothed over the coordinate grid by a local
        linear fit whose weights are the product, over coordinate dimensions,
        of ep_kernel evaluated with the bandwidths in row h_ii of vh. The
        score of a measurement is the sum of squared differences between the
        smoothed and the raw residuals.

        Args:
            resy_design (matrix): residuals of imaging response data (n*l*m)
            coord_mat (matrix): common coordinate matrix (l*d)
            vh (matrix): candidate bandwidth matrix (nh*d)
            h_ii (scalar): candidate bandwidth index (row of vh to evaluate)
        Outputs:
            gcv (vector): bandwidth selection index, one entry per imaging
                measurement (m); smaller is better for the caller's argmin.
                NOTE(review): this is a plain residual sum of squares, no
                degrees-of-freedom correction is applied - confirm that this
                is the intended selection criterion.
    """

    # Set up
    l, d = coord_mat.shape
    m = resy_design.shape[2]
    efit_resy = resy_design * 0
    gcv = np.zeros(m)

    # Selector for the intercept of the local linear fit, i.e. the fitted
    # value at the grid point the fit is centred on. Only element [0, 0] is
    # set: `w[0] = 1` would fill the whole (1, d+1) row with ones and add the
    # slope estimates to the fitted value.
    w = np.zeros((1, d + 1))
    w[0, 0] = 1

    # t_mat0[i, j, 0] = 1 and t_mat0[i, j, k + 1] = coord[i, k] - coord[j, k]
    t_mat0 = np.zeros((l, l, d + 1))  # L x L x d + 1 matrix
    t_mat0[:, :, 0] = 1
    for dii in range(d):
        coord_dii = coord_mat[:, dii]
        t_mat0[:, :, dii + 1] = coord_dii[:, np.newaxis] - coord_dii[np.newaxis, :]

    # t_mat[:, :, j] is the local design matrix centred at grid point j
    t_mat = np.transpose(t_mat0, [0, 2, 1])  # L x d+1 x L matrix

    # Product kernel: multiply the one-dimensional kernels of all dimensions
    # BEFORE smoothing, so the (expensive) fit below runs once instead of
    # once per dimension with a partially built kernel.
    k_mat = np.ones((l, l))
    for dii in range(d):
        h = vh[h_ii, dii]
        k_mat = k_mat * ep_kernel(t_mat0[:, :, dii + 1] / h, h)  # Epanechnikov kernel smoothing function

    # Small ridge term keeps the local normal equations invertible.
    ridge = np.eye(d + 1) * 0.0001

    for lii in range(l):
        x_loc = t_mat[:, :, lii]  # L x d+1 matrix
        kx = k_mat[:, lii][:, np.newaxis] * x_loc  # kernel-weighted design, L x d+1 matrix
        # The smoothing weights depend only on the grid point, not on the
        # measurement, so they are computed once and reused for all m.
        sm_weight = np.dot(np.dot(w, inv(np.dot(kx.T, x_loc) + ridge)), kx.T)  # 1 x L matrix
        for mii in range(m):
            efit_resy[:, lii, mii] = np.squeeze(np.dot(resy_design[:, :, mii], sm_weight.T))

    for mii in range(m):
        gcv[mii] = np.sum((efit_resy[:, :, mii] - resy_design[:, :, mii]) ** 2)

    return gcv


def run_step1(input_dir, output_dir):

    """
    Run the commandline script for FGWAS (step 1: MVCM).

    Reads img_data.mat, coord_data.txt, design_data.txt and var_type.txt from
    input_dir, builds the design matrices with read_xy, picks for every
    imaging measurement the candidate bandwidth with the smallest bw() score,
    fits the model with mvcm, and writes the resulting matrices as .mat files
    into the "temp" sub-folder of output_dir (created if it does not exist).

    :param
        input_dir (str): full path to the data folder
        output_dir (str): full path to the output folder
    """

    # +++++++++++++++++++++++++++++++++++
    print(""" Step 0. load dataset """)
    print("+++++++Read the imaging data+++++++")
    # os.path.join works with or without a trailing separator on input_dir.
    img_file_name = os.path.join(input_dir, "img_data.mat")
    mat = loadmat(img_file_name)
    img_data = mat['img_data']
    if len(img_data.shape) == 2:
        # a single imaging measurement: add the trailing measurement axis
        img_data = img_data.reshape(img_data.shape[0], img_data.shape[1], 1)
    n, _, m = img_data.shape
    print("The matrix dimension of image data is " + str(img_data.shape))
    print("+++++++Read the imaging coordinate data+++++++")
    coord_file_name = os.path.join(input_dir, "coord_data.txt")
    coord_data = np.loadtxt(coord_file_name)
    print("The matrix dimension of coordinate data is " + str(coord_data.shape))
    print("+++++++Read the covariate data+++++++")
    design_data_file_name = os.path.join(input_dir, "design_data.txt")
    design_data = np.loadtxt(design_data_file_name)
    print("The matrix dimension of covariate data is " + str(design_data.shape))

    # read the covariate type
    var_type_file_name = os.path.join(input_dir, "var_type.txt")
    # np.loadtxt returns a 0-d array for a single-entry file, which cannot be
    # iterated; atleast_1d makes the one-covariate case work as well.
    var_type = np.atleast_1d(np.loadtxt(var_type_file_name))
    var_type = np.array([int(i) for i in var_type])

    print("+++++++++Matrix preparing and Data preprocessing++++++++")
    print("+++++++Construct the imaging response, design, coordinate matrix: normalization+++++++")
    y_design, x_design, coord_data = read_xy(coord_data, img_data, design_data, var_type)
    d = coord_data.shape[1]
    p = x_design.shape[1]
    print("The dimension of normalized design matrix is " + str(x_design.shape))

    # +++++++++++++++++++++++++++++++++++
    print(""" Step 1. fit the multivariate varying coefficient model (MVCM) under H0 """)
    start_1 = time.time()

    # find the optimal bandwidth
    nh = 20  # number of candidate bandwidths per coordinate dimension
    h_min = 0.01  # minimum bandwidth
    vh = np.zeros((nh, d))
    for dii in range(d):
        coord_range = np.ptp(coord_data[:, dii])
        h_max = 0.6 * coord_range  # maximum bandwidth
        vh[:, dii] = np.logspace(np.log10(h_min), np.log10(h_max), nh)  # candidate bandwidth

    # calculate the hat matrix (small ridge term keeps X'X invertible)
    hat_mat = np.dot(np.dot(x_design, inv(np.dot(x_design.T, x_design) + np.eye(p) * 0.0001)), x_design.T)

    # residuals of the responses after projecting out the design matrix
    resid_proj = np.eye(n) - hat_mat
    resy_design = y_design * 0
    for mii in range(m):
        resy_design[:, :, mii] = np.dot(resid_proj, y_design[:, :, mii])

    # Score all candidate bandwidths in parallel. apply_async submits every
    # job before waiting (pool.apply would block on each call and run them one
    # at a time); the pool is always shut down, also when a worker fails.
    pool = Pool()
    try:
        jobs = [pool.apply_async(bw, args=(resy_design, coord_data, vh, h_ii)) for h_ii in range(nh)]
        result = [job.get() for job in jobs]
    finally:
        pool.terminate()
        pool.join()

    gcv = np.zeros((nh, m))
    for ii in range(nh):
        gcv[ii, :] = result[ii]
    h_opt_idx = np.argmin(gcv, axis=0)

    # h_opt[dii, mii] = vh[h_opt_idx[mii], dii]
    h_opt = np.ascontiguousarray(vh[h_opt_idx, :].T)  # d x m matrix
    print("the optimal bandwidth is ", h_opt)

    # fit MVCM
    qr_smy_mat, inv_sig_eta, smy_design, resy_design, efit_eta = mvcm(coord_data, y_design, h_opt, hat_mat)
    for mii in range(m):
        res_mii = resy_design[:, :, mii]-efit_eta[:, :, mii]
        print("The bound of the residual is [" + str(np.min(res_mii)) + ", " + str(np.max(res_mii)) + "]")
    end_1 = time.time()
    print("Elapsed time in Step 1 is ", end_1 - start_1)

    # Save the intermediate results, one <name>.mat file per matrix with the
    # matrix stored under the key <name>.
    temp_dir = os.path.join(output_dir, "temp")
    if not os.path.isdir(temp_dir):
        os.makedirs(temp_dir)

    outputs = [
        ('x_design', x_design),
        ('coord_data', coord_data),
        ('y_design', y_design),
        ('resy_design', resy_design),
        ('smy_design', smy_design),
        ('qr_smy_mat', qr_smy_mat),
        ('efit_eta', efit_eta),
        ('inv_sig_eta', inv_sig_eta),
        ('hat_mat', hat_mat),
    ]
    for name, value in outputs:
        savemat(os.path.join(temp_dir, name + ".mat"), mdict={name: value})


if __name__ == '__main__':
    # Both folder arguments are required; stop with a usage message instead
    # of an IndexError traceback when one of them is missing.
    if len(sys.argv) < 3:
        sys.exit("Usage: python ./fgwas_step1.py ./data/ ./result/")
    input_dir0 = sys.argv[1]
    output_dir0 = sys.argv[2]
    run_step1(input_dir0, output_dir0)
